"""
Analyzes the dump of Kafka data created by benchmark.py.
Results (e.g. each merchant's profit and revenue) are saved to a CSV file,
and per-merchant time-series charts are saved as PNG images.
"""

import argparse
import csv
import datetime
import os
import json
from collections import defaultdict

import matplotlib
matplotlib.use('Agg') # required for headless plotting

import matplotlib.pyplot as plt


def load_merchant_id_mapping(directory):
    """Return the parsed ``merchant_id_mapping.json`` stored in *directory*.

    The result maps each merchant id to the display name used in the CSV
    report and chart legends.
    """
    mapping_path = os.path.join(directory, 'merchant_id_mapping.json')
    with open(mapping_path) as mapping_file:
        mapping = json.load(mapping_file)
    return mapping


def _sum_per_merchant(path, amount_of):
    """Sum ``amount_of(event)`` per ``event['merchant_id']`` over the JSON event list at *path*.

    Returns a ``defaultdict(float)`` so merchants without events read as 0.0.
    """
    totals = defaultdict(float)
    with open(path) as file:
        for event in json.load(file):
            totals[event['merchant_id']] += amount_of(event)
    return totals


def analyze_kafka_dump(directory):
    """Aggregate per-merchant financials from a benchmark dump and render charts.

    Reads the merchant id -> name mapping and the ``buyOffer``, ``holding_cost``
    and ``producer`` topic dumps under ``<directory>/kafka``:

    * revenue is the sum of ``amount * price`` over ``buyOffer`` events,
    * holding cost is the sum of ``cost`` over ``holding_cost`` events,
    * order cost is the sum of ``billing_amount`` over ``producer`` events,
    * profit is revenue minus holding cost minus order cost.

    One row per merchant in the mapping, sorted by merchant name, is written to
    ``<directory>/results.csv``. Afterwards, charts for several other topics are
    saved as PNG files in *directory*.

    Raises:
        FileNotFoundError: if the mapping file or one of the three aggregated
            topic files is missing.
    """
    merchant_id_mapping = load_merchant_id_mapping(directory)
    kafka_directory = os.path.join(directory, 'kafka')

    revenue = _sum_per_merchant(os.path.join(kafka_directory, 'buyOffer'),
                                lambda event: event['amount'] * event['price'])
    holding_cost = _sum_per_merchant(os.path.join(kafka_directory, 'holding_cost'),
                                     lambda event: event['cost'])
    order_cost = _sum_per_merchant(os.path.join(kafka_directory, 'producer'),
                                   lambda event: event['billing_amount'])

    profit = {merchant_id: revenue[merchant_id] - holding_cost[merchant_id] - order_cost[merchant_id]
              for merchant_id in merchant_id_mapping}

    # The csv module requires newline='' so it controls line endings itself;
    # without it, rows are separated by blank lines on Windows.
    with open(os.path.join(directory, 'results.csv'), 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['name', 'revenue', 'holding_cost', 'order_cost', 'profit'])
        # Sort rows by merchant name (the mapping's values), not by id.
        for merchant_id in sorted(merchant_id_mapping, key=merchant_id_mapping.get):
            writer.writerow([merchant_id_mapping[merchant_id], revenue[merchant_id], holding_cost[merchant_id],
                             order_cost[merchant_id], profit[merchant_id]])

    create_chart(directory, merchant_id_mapping,
                 topic='inventory_level', value_name='level', label='Inventory Level',
                 filename='inventory_levels.png', drawstyle='steps-post')
    create_chart(directory, merchant_id_mapping,
                 topic='profitPerMinute', value_name='profit', label='Profit per Minute',
                 filename='profit_per_minute.png')
    create_chart(directory, merchant_id_mapping,
                 topic='revenuePerMinute', value_name='revenue', label='Revenue per Minute',
                 filename='revenue_per_minute.png')
    create_chart(directory, merchant_id_mapping,
                 topic='profit', value_name='profit', label='Cumulative Profit',
                 filename='cumulative_profit.png')
    create_chart(directory, merchant_id_mapping,
                 topic='cumulativeRevenue', value_name='revenue', label='Cumulative Revenue',
                 filename='cumulative_revenue.png')


def parse_timestamps(events):
    """Replace each event's ``'timestamp'`` string with a timezone-aware datetime, in place.

    Two ISO-8601 layouts are accepted:
    * with fractional seconds, e.g. ``2020-01-01T12:00:00.123Z``;
    * without them, e.g. ``2020-01-01T12:00:00Z`` (used by topic 'inventory_level').

    The trailing ``Z`` (or a numeric offset such as ``+01:00``) is parsed with
    ``%z``, so the UTC offset is kept instead of being silently discarded.

    Raises:
        ValueError: if a timestamp matches neither layout.
    """
    for event in events:
        raw = event['timestamp']
        try:
            event['timestamp'] = datetime.datetime.strptime(raw, '%Y-%m-%dT%H:%M:%S.%f%z')
        except ValueError:
            # Dates in topic 'inventory_level' have no milliseconds
            event['timestamp'] = datetime.datetime.strptime(raw, '%Y-%m-%dT%H:%M:%S%z')


def create_chart(directory, merchant_id_mapping, topic, value_name, label, filename, **options):
    """Plot one time series per merchant from a dumped Kafka topic and save it as an image.

    The JSON event list in ``<directory>/kafka/<topic>`` is read. If that file does
    not exist, a notice is printed and no chart is written.

    Args:
        directory: Dump directory; the chart is saved to ``<directory>/<filename>``.
        merchant_id_mapping: Dict of merchant id -> display name. Only these merchants
            are plotted (in the mapping's order), and their names become legend labels.
        topic: Name of the topic file under ``<directory>/kafka``.
        value_name: Event key whose value is plotted on the y axis.
        label: Y-axis label.
        filename: Output image file name.
        **options: Extra keyword arguments forwarded to ``Axes.plot`` (e.g. ``drawstyle``).
    """
    from matplotlib.dates import DateFormatter

    input_file = os.path.join(directory, 'kafka', topic)
    try:
        # Use a context manager so the file handle is closed promptly.
        with open(input_file) as file:
            events = json.load(file)
    except FileNotFoundError:
        print('Could not find file', input_file)
        print('Skip generating graph', filename)
        return

    parse_timestamps(events)

    # Bucket timestamps and values per merchant in one pass over the events,
    # instead of rescanning every event twice for each merchant.
    series = {merchant_id: ([], []) for merchant_id in merchant_id_mapping}
    for event in events:
        bucket = series.get(event['merchant_id'])
        if bucket is not None:
            dates, values = bucket
            dates.append(event['timestamp'])
            values.append(event[value_name])

    fig, ax = plt.subplots()
    try:
        for merchant_id, (dates, values) in series.items():
            # Cannot plot if no events belong to that merchant
            if dates:
                ax.plot(dates, values, label=merchant_id_mapping[merchant_id], **options)
        ax.set_xlabel('Time')
        ax.set_ylabel(label)
        # With no labelled lines, fig.legend() would only emit a warning.
        if ax.lines:
            fig.legend()
        ax.xaxis.set_major_formatter(DateFormatter('%Y-%m-%d %H:%M:%S'))
        fig.autofmt_xdate()  # rotate and align dates
        fig.tight_layout()  # fit dates into picture
        fig.savefig(os.path.join(directory, filename))
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)


def main():
    """Command-line entry point: read ``--directory``/``-d`` and analyze that dump."""
    arg_parser = argparse.ArgumentParser(description='Analyzes data generated by benchmark.py')
    arg_parser.add_argument('--directory', '-d', type=str, required=True)
    cli_args = arg_parser.parse_args()
    dump_directory = cli_args.directory
    analyze_kafka_dump(dump_directory)


# Run the analysis only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
